# SPDX-License-Identifier: LGPL-3.0-or-later
import dpdata
import numpy as np

from deepmd.tf.descriptor import (
    DescrptSeA,
)
from deepmd.tf.env import (
    tf,
)
from deepmd.tf.fit import (
    EnerFitting,
)
from deepmd.tf.model import (
    EnerModel,
)
from deepmd.tf.utils.type_embed import (
    TypeEmbedNet,
)

from .common import (
    DataSystem,
    del_data,
    gen_data,
    j_loader,
)

# Double-precision dtypes used by the placeholders in these tests: the energy
# placeholder and the other TF float placeholders share tf.float64, and
# GLOBAL_NP_FLOAT_PRECISION is the matching NumPy dtype.
GLOBAL_ENER_FLOAT_PRECISION = GLOBAL_TF_FLOAT_PRECISION = tf.float64
GLOBAL_NP_FLOAT_PRECISION = np.float64


class TestModel(tf.test.TestCase):
    """Regression tests for a TF ``EnerModel`` built from ``DescrptSeA`` + ``EnerFitting``.

    NOTE(review): ``setUp``/``tearDown`` do not call the ``tf.test.TestCase``
    implementations; each test passes a distinct ``suffix`` to ``model.build``,
    presumably so the tests can share one default graph without variable-name
    clashes. Confirm before changing either.
    """

    def setUp(self) -> None:
        gen_data()

    def tearDown(self) -> None:
        del_data()

    @staticmethod
    def _write_single_atom_system() -> None:
        """Write a two-frame, single-atom labeled system to ``system`` (deepmd npy).

        Each frame holds one atom of type 0 at the origin of a 30x30x30 cubic
        cell, with zero energies and forces. Presumably ``system`` is the path
        listed under ``systems`` in ``water_se_a.json`` — confirm there.
        """
        nframes = 2
        natoms = 1
        system = dpdata.LabeledSystem()
        system.data["atom_names"] = ["foo", "bar"]
        system.data["coords"] = np.array([0, 0, 0, 0, 0, 0]).reshape(
            [nframes, natoms, 3]
        )
        system.data["atom_types"] = [0]
        # Already shaped (nframes, 3, 3): one 30 * identity cell per frame.
        system.data["cells"] = np.array([np.eye(3) * 30, np.eye(3) * 30])
        system.data["energies"] = np.zeros([nframes, 1])
        system.data["forces"] = np.zeros([nframes, natoms, 3])
        system.to_deepmd_npy("system", prec=np.float64)

    @staticmethod
    def _make_data_system(jdata: dict) -> DataSystem:
        """Load ``jdata["systems"]`` with batch size 1 and test size 1.

        Parameters
        ----------
        jdata
            Parsed training input; reads ``systems`` and ``model.descriptor.rcut``.
        """
        batch_size = 1
        test_size = 1
        return DataSystem(
            jdata["systems"],
            "set",
            batch_size,
            test_size,
            jdata["model"]["descriptor"]["rcut"],
            run_opt=None,
        )

    @staticmethod
    def _build_model(jdata: dict, typeebd: "TypeEmbedNet | None" = None) -> EnerModel:
        """Build an ``EnerModel`` from the descriptor / fitting sections of ``jdata``.

        Mutates ``jdata``: drops ``model.descriptor.type`` and fills the fitting
        net's ``ntypes``, ``dim_descrpt`` and ``dim_rot_mat_1`` from the descriptor.

        Parameters
        ----------
        jdata
            Parsed training input.
        typeebd
            Optional type-embedding network; passed to ``EnerModel`` only when given.
        """
        # The descriptor class is fixed here, so its "type" key is not a valid kwarg.
        jdata["model"]["descriptor"].pop("type", None)
        descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True)
        fitting_param = jdata["model"]["fitting_net"]
        fitting_param["ntypes"] = descrpt.get_ntypes()
        fitting_param["dim_descrpt"] = descrpt.get_dim_out()
        fitting_param["dim_rot_mat_1"] = descrpt.get_dim_rot_mat_1()
        fitting = EnerFitting(**fitting_param, uniform_seed=True)
        if typeebd is None:
            return EnerModel(descrpt, fitting)
        return EnerModel(descrpt, fitting, typeebd=typeebd)

    @staticmethod
    def _input_stat_data(test_data: dict) -> dict:
        """Wrap the fields needed by ``_compute_input_stat`` as one-element lists."""
        return {
            key: [test_data[key]]
            for key in ("coord", "box", "type", "natoms_vec", "default_mesh")
        }

    def _check_atom_ener(self, suffix: str, use_type_embedding: bool) -> None:
        """Check that a lone atom's predicted energy equals its configured ``atom_ener``.

        Writes the single-atom system, builds the model (optionally with a type
        embedding), then evaluates the energy for the atom labelled as type 0
        and as type 1, comparing against ``set_atom_ener`` to 10 places.

        Parameters
        ----------
        suffix
            Variable-scope suffix for ``model.build``; must be unique per test.
        use_type_embedding
            Whether to attach a ``TypeEmbedNet`` to the model.
        """
        jdata = j_loader("water_se_a.json")
        set_atom_ener = [0.02, 0.01]
        jdata["model"]["fitting_net"]["atom_ener"] = set_atom_ener
        if use_type_embedding:
            jdata["model"]["type_embeding"] = {"neuron": [2]}

        self._write_single_atom_system()
        data = self._make_data_system(jdata)
        test_data = data.get_test()
        numb_test = 1

        typeebd = None
        if use_type_embedding:
            # Built before _build_model pops descriptor keys, as in the
            # original test; only "sel" is read here.
            typeebd = TypeEmbedNet(
                ntypes=len(jdata["model"]["descriptor"]["sel"]),
                **jdata["model"]["type_embeding"],
                use_tebd_bias=True,
            )
        model = self._build_model(jdata, typeebd)

        # Presumably [nloc, nall, n_type0, n_type1]: one atom, of type 0.
        test_data["natoms_vec"] = [1, 1, 1, 0]
        model._compute_input_stat(self._input_stat_data(test_data))
        model.fitting.bias_atom_e = np.array(set_atom_ener)

        t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c")
        t_energy = tf.placeholder(GLOBAL_ENER_FLOAT_PRECISION, [None], name="t_energy")
        t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord")
        t_type = tf.placeholder(tf.int32, [None], name="i_type")
        t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2], name="i_natoms")
        t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box")
        t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh")
        is_training = tf.placeholder(tf.bool)
        t_fparam = None

        model_pred = model.build(
            t_coord,
            t_type,
            t_natoms,
            t_box,
            t_mesh,
            t_fparam,
            suffix=suffix,
            reuse=False,
        )
        energy = model_pred["energy"]
        force = model_pred["force"]
        virial = model_pred["virial"]

        feed_dict_test = {
            t_prop_c: test_data["prop_c"],
            t_energy: test_data["energy"][:numb_test],
            t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]),
            t_box: test_data["box"][:numb_test, :],
            t_mesh: test_data["default_mesh"],
            is_training: False,
        }
        # Use the context manager so the session's default-graph/session
        # contexts are popped when the test finishes.
        with self.cached_session() as sess:
            sess.run(tf.global_variables_initializer())
            for atom_type, natoms_vec in ((0, [1, 1, 1, 0]), (1, [1, 1, 0, 1])):
                feed_dict_test[t_type] = np.reshape([atom_type], [-1])
                feed_dict_test[t_natoms] = natoms_vec
                # Force and virial are still evaluated so a failure in those
                # ops surfaces here, but only the energy is checked.
                e, _, _ = sess.run([energy, force, virial], feed_dict=feed_dict_test)
                self.assertAlmostEqual(e[0], set_atom_ener[atom_type], places=10)

    def test_model_atom_ener(self) -> None:
        """Lone-atom energy equals ``atom_ener`` without type embedding."""
        self._check_atom_ener("se_a_atom_ener_0", use_type_embedding=False)

    def test_model(self) -> None:
        """Compare energy, force and virial against stored reference values.

        Uses the systems listed in ``water_se_a.json``.
        """
        jdata = j_loader("water_se_a.json")
        data = self._make_data_system(jdata)
        test_data = data.get_test()
        numb_test = 1

        model = self._build_model(jdata)
        model._compute_input_stat(self._input_stat_data(test_data))
        # NOTE(review): this assigns to the descriptor, not ``model.fitting`` as
        # the atom_ener tests do, and may have no effect. The reference values
        # below were presumably generated with this exact setup, so it is kept.
        model.descrpt.bias_atom_e = data.compute_energy_shift()

        t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c")
        t_energy = tf.placeholder(GLOBAL_ENER_FLOAT_PRECISION, [None], name="t_energy")
        t_force = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_force")
        t_virial = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_virial")
        t_atom_ener = tf.placeholder(
            GLOBAL_TF_FLOAT_PRECISION, [None], name="t_atom_ener"
        )
        t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord")
        t_type = tf.placeholder(tf.int32, [None], name="i_type")
        t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2], name="i_natoms")
        t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box")
        t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh")
        is_training = tf.placeholder(tf.bool)
        t_fparam = None

        model_pred = model.build(
            t_coord,
            t_type,
            t_natoms,
            t_box,
            t_mesh,
            t_fparam,
            suffix="se_a",
            reuse=False,
        )
        energy = model_pred["energy"]
        force = model_pred["force"]
        virial = model_pred["virial"]

        feed_dict_test = {
            t_prop_c: test_data["prop_c"],
            t_energy: test_data["energy"][:numb_test],
            t_force: np.reshape(test_data["force"][:numb_test, :], [-1]),
            t_virial: np.reshape(test_data["virial"][:numb_test, :], [-1]),
            t_atom_ener: np.reshape(test_data["atom_ener"][:numb_test, :], [-1]),
            t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]),
            t_box: test_data["box"][:numb_test, :],
            t_type: np.reshape(test_data["type"][:numb_test, :], [-1]),
            t_natoms: test_data["natoms_vec"],
            t_mesh: test_data["default_mesh"],
            is_training: False,
        }

        with self.cached_session() as sess:
            sess.run(tf.global_variables_initializer())
            e, f, v = sess.run([energy, force, virial], feed_dict=feed_dict_test)

        e = e.reshape([-1])
        f = f.reshape([-1])
        v = v.reshape([-1])
        refe = [6.135449167779321300e01]
        reff = [
            7.799691562262310585e-02,
            9.423098804815030483e-02,
            3.790560997388224204e-03,
            1.432522403799846578e-01,
            1.148392791403983204e-01,
            -1.321871172563671148e-02,
            -7.318966526325138000e-02,
            6.516069212737778116e-02,
            5.406418483320515412e-04,
            5.870713761026503247e-02,
            -1.605402669549013672e-01,
            -5.089516979826595386e-03,
            -2.554593467731766654e-01,
            3.092063507347833987e-02,
            1.510355029451411479e-02,
            4.869271842355533952e-02,
            -1.446113274345035005e-01,
            -1.126524434771078789e-03,
        ]
        refv = [
            -6.076776685178300053e-01,
            1.103174323630009418e-01,
            1.984250991380156690e-02,
            1.103174323630009557e-01,
            -3.319759402259439551e-01,
            -6.007404107650986258e-03,
            1.984250991380157036e-02,
            -6.007404107650981921e-03,
            -1.200076017439753642e-03,
        ]
        refe = np.reshape(refe, [-1])
        reff = np.reshape(reff, [-1])
        refv = np.reshape(refv, [-1])

        places = 10
        np.testing.assert_almost_equal(e, refe, places)
        np.testing.assert_almost_equal(f, reff, places)
        np.testing.assert_almost_equal(v, refv, places)

        # test input requirement for the model
        self.assertCountEqual(model.input_requirement, [])

    def test_model_atom_ener_type_embedding(self) -> None:
        """Test atom ener with type embedding."""
        self._check_atom_ener("se_a_atom_ener_type_embbed_0", use_type_embedding=True)
